"""
@Description :   Use keyword matching to select medical-related samples from parquet files (the raw dataset format on Hugging Face)
                 and save the selected samples as JSON. To filter data from other sources, modify the `filter_file` method.
@Author      :   Henrychur 
@Time        :   2023/10/11 09:44:36
"""
import json
import multiprocessing
import os
from collections import Counter

import pandas as pd
import tqdm

from keywords import *

# Input folder holding the downloaded parquet files and output folder for the
# filtered JSON files. The paths are built with os.path.join so they resolve to
# nested folders on every platform (on Windows the resulting strings are the
# same as the previous hard-coded backslash paths).
PARQUET_FOLDER = os.path.join('es', 'downloads')
SAVE_ROOT = os.path.join('.', 'es', 'es_medical')
# Language whose thresholds are looked up in THRESHOLD below.
LANG = 'es'

# These thresholds were set while testing on the multilingual dataset CulturaX.
# HIT_TIMES_THRESHOLD: minimum number of distinct keywords found in a text.
# HIT_RATIO_THRESHOLD: minimum share of the text length covered by keyword hits.
THRESHOLD = {
    'en': {
        'HIT_TIMES_THRESHOLD': 5,
        'HIT_RATIO_THRESHOLD': 0.04,
    },
    'zh': {
        'HIT_TIMES_THRESHOLD': 5,
        'HIT_RATIO_THRESHOLD': 0.05,
    },
    'jp': {
        'HIT_TIMES_THRESHOLD': 5,
        'HIT_RATIO_THRESHOLD': 0.05,
    },
    'ru': {
        'HIT_TIMES_THRESHOLD': 4,
        'HIT_RATIO_THRESHOLD': 0.02,
    },
    'es': {
        'HIT_TIMES_THRESHOLD': 4,
        'HIT_RATIO_THRESHOLD': 0.04,
    },
    'fr': {
        'HIT_TIMES_THRESHOLD': 4,
        'HIT_RATIO_THRESHOLD': 0.04,
    },
}

# Thresholds in effect for the configured language; raises KeyError when LANG
# has no entry in THRESHOLD.
HIT_TIMES_THRESHOLD = THRESHOLD[LANG]['HIT_TIMES_THRESHOLD']
HIT_RATIO_THRESHOLD = THRESHOLD[LANG]['HIT_RATIO_THRESHOLD']

class MedicalTextFilter:
    """Select medical-related samples from parquet files and save them as JSON.

    A text counts as medical when it contains at least ``HIT_TIMES_THRESHOLD``
    distinct keywords and the matched keyword characters make up at least
    ``HIT_RATIO_THRESHOLD`` of the text length.
    """

    # Languages written without spaces between words; they are matched by
    # substring search instead of whitespace tokenization.
    SPACELESS_LANGS = ('zh', 'jp')
    # Only parquet files strictly larger than this many bytes (1 GiB) are processed.
    MIN_FILE_SIZE = 1 * 1024 * 1024 * 1024
    # Upper bound on the number of worker processes.
    MAX_PROCESSES = 32

    def __init__(self):
        # Tuple ``(long_keywords, short_keywords)``, see ``load_keywords``.
        self.keywords = self.load_keywords()

    @staticmethod
    def load_keywords():
        """
            Load and process medical keywords for filtering.
            There are 200 medical keywords for each language, which are generated by ChatGPT.

            Returns a tuple ``(long_keywords, short_keywords)`` of lowercase,
            de-duplicated keywords. Long keywords (more than 10 characters, or
            multi-word with more than 5 characters) are matched as substrings;
            short keywords are matched against whole whitespace-separated tokens.
        """
        raw_keywords = []
        raw_keywords.extend(es_medical_biology_words)  # Change this to current language
        raw_keywords.extend(es_medical_words)  # Change this to current language
        # Lowercase BEFORE de-duplicating so that case variants of one keyword
        # are not counted as separate hits; drop empty entries (an empty string
        # is a substring of every text); sort for a deterministic order.
        keywords = sorted({keyword.lower() for keyword in raw_keywords if keyword})
        long_keywords, short_keywords = [], []
        for keyword in keywords:
            if len(keyword) > 10 or (len(keyword) > 5 and len(keyword.split()) > 1):
                long_keywords.append(keyword)
            else:
                short_keywords.append(keyword)
        return long_keywords, short_keywords

    @staticmethod
    def _is_medical(unique_hit_cnt, hit_cnt, text_len):
        """Return True when both the hit-count and hit-ratio thresholds are met."""
        return unique_hit_cnt >= HIT_TIMES_THRESHOLD and hit_cnt / text_len >= HIT_RATIO_THRESHOLD

    @staticmethod
    def _save_path(filename):
        """Return the JSON output path in SAVE_ROOT for a parquet file name or path."""
        # Accept both '\\' and '/' as separators so that paths built on any
        # platform map to the same output file name.
        basename = filename.replace('\\', '/').rsplit('/', 1)[-1]
        return os.path.join(SAVE_ROOT, basename.split('.')[0] + '.json')

    def match_keywords(self, text, en_like=True):
        """
            Match keywords in the given text.
            Notice that English-like languages use spaces as intervals between words, while Chinese-like languages do not.
            So we apply different matching strategies for them.

            Args:
                text: the sample to classify. Empty or non-string values
                    (e.g. missing values in the parquet column) never match.
                en_like: True to match short keywords against whitespace
                    tokens; False to match every keyword as a substring.

            Returns:
                True if the text reaches both thresholds, False otherwise.
        """
        if not isinstance(text, str) or not text:
            return False
        long_keywords, short_keywords = self.keywords
        # Keywords are stored lowercase, so the text is lowercased as well. The
        # ratio still uses the length of the original text.
        lower_text = text.lower()
        text_len = len(text)
        unique_hit_cnt, hit_cnt = 0, 0
        if en_like:
            for long_keyword in long_keywords:
                occurrences = lower_text.count(long_keyword)
                if occurrences:
                    unique_hit_cnt += 1
                    hit_cnt += occurrences * len(long_keyword)
            # Long keywords alone may already be enough; skip tokenization then.
            if self._is_medical(unique_hit_cnt, hit_cnt, text_len):
                return True
            # Count every token once instead of scanning the token list per keyword.
            token_counts = Counter(lower_text.split())
            for short_keyword in short_keywords:
                occurrences = token_counts.get(short_keyword, 0)
                if occurrences:
                    unique_hit_cnt += 1
                    hit_cnt += occurrences * len(short_keyword)
        else:
            # No word boundaries to rely on: every keyword is a substring match.
            for keyword in long_keywords + short_keywords:
                occurrences = lower_text.count(keyword)
                if occurrences:
                    unique_hit_cnt += 1
                    hit_cnt += occurrences * len(keyword)
        return self._is_medical(unique_hit_cnt, hit_cnt, text_len)

    def filter_file(self, filename):
        """Filter medical-related samples from a parquet file.

        Reads the ``text`` column of ``filename``, keeps the texts accepted by
        ``match_keywords`` and writes them, together with counts and the
        thresholds in use, to a JSON file in ``SAVE_ROOT``.
        """
        # Only the 'text' column is needed, so avoid loading the others.
        texts = pd.read_parquet(filename, columns=['text'])['text']
        en_like = LANG not in self.SPACELESS_LANGS
        medical_texts = [
            text for text in tqdm.tqdm(texts)
            if self.match_keywords(text, en_like=en_like)
        ]
        text_cnt = len(texts)
        medical_text_cnt = len(medical_texts)
        # An empty file has a ratio of 0 instead of raising ZeroDivisionError.
        medical_text_ratio = medical_text_cnt / text_cnt if text_cnt else 0.0
        print(f'{filename} has been filtered. Medical text count: {medical_text_cnt}/{text_cnt} \tRatio: {medical_text_ratio*100:.2f}%')
        result = {
            'text_cnt': text_cnt,
            'medical_text_cnt': medical_text_cnt,
            'medical_text_ratio': medical_text_ratio,
            'HIT_TIMES_THRESHOLD': HIT_TIMES_THRESHOLD,
            'HIT_RATIO_THRESHOLD': HIT_RATIO_THRESHOLD,
            'medical_texts': medical_texts
        }
        os.makedirs(SAVE_ROOT, exist_ok=True)
        with open(self._save_path(filename), 'w', encoding='utf-8') as f:
            json.dump(result, f, ensure_ascii=False, indent=4)

    def filter_data(self):
        """Filter data across multiple files.

        Processes every file in ``PARQUET_FOLDER`` that is larger than
        ``MIN_FILE_SIZE`` and has no JSON result in ``SAVE_ROOT`` yet, using a
        pool of worker processes.
        """
        filenames = []
        for filename in sorted(os.listdir(PARQUET_FOLDER)):
            path = os.path.join(PARQUET_FOLDER, filename)
            if os.path.getsize(path) <= self.MIN_FILE_SIZE:
                continue
            if os.path.exists(self._save_path(filename)):
                continue  # already filtered in a previous run
            # Hand the full path to the workers; a bare name would only resolve
            # when the working directory happens to be PARQUET_FOLDER.
            filenames.append(path)
        print(f'Total {len(filenames)} files need to be filtered.')
        if not filenames:
            return
        processes = min(self.MAX_PROCESSES, len(filenames))
        with multiprocessing.Pool(processes=processes) as pool:
            pool.map(self.filter_file, filenames)


if __name__ == "__main__":
    # Entry point: filter every pending parquet file in PARQUET_FOLDER.
    # The variable is not called ``filter`` to avoid shadowing the built-in.
    medical_filter = MedicalTextFilter()
    medical_filter.filter_data()
